{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from pyhanlp import HanLP\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Read the raw dataset, segment every sample into words and save the vocabulary.\n",
    "def read_toutiao_dataset(data_path, save_vocab_path):\n",
    "    \"\"\"Segment a ``category_!_content`` corpus and write its vocabulary to disk.\n",
    "\n",
    "    Args:\n",
    "        data_path: UTF-8 text file holding one ``category_!_content`` sample per line.\n",
    "        save_vocab_path: Output file; receives one word per line, in order of\n",
    "            first appearance in the corpus.\n",
    "\n",
    "    Returns:\n",
    "        A ``(datas, labels)`` tuple: ``datas[i]`` is the list of segmented words\n",
    "        of sample ``i`` and ``labels[i]`` is its category string.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If a non-blank line does not contain the ``_!_`` separator.\n",
    "    \"\"\"\n",
    "    with open(data_path, \"r\", encoding=\"utf8\") as fo:\n",
    "        all_lines = fo.readlines()\n",
    "    datas, labels = [], []\n",
    "    # Vocabulary in first-seen order; the set only provides O(1) membership tests.\n",
    "    vocab_words, seen_words = [], set()\n",
    "    for line_no, line in enumerate(tqdm(all_lines), start=1):\n",
    "        line = line.strip()\n",
    "        if not line:\n",
    "            # Blank lines (e.g. a trailing newline at the end of the file) carry no sample.\n",
    "            continue\n",
    "        # Split on the first separator only, so a headline that itself contains\n",
    "        # \"_!_\" stays intact instead of breaking the two-value unpacking.\n",
    "        fields = line.split(\"_!_\", 1)\n",
    "        if len(fields) != 2:\n",
    "            raise ValueError(\"line %d of %s has no '_!_' separator: %r\" % (line_no, data_path, line))\n",
    "        category, content = fields\n",
    "        content_words = [term.word for term in HanLP.segment(content)]\n",
    "        for word in content_words:\n",
    "            # Whitespace-only tokens cannot survive the one-word-per-line file\n",
    "            # (read_word_vocabs strips every line), so keep them out of the\n",
    "            # vocabulary; they are indexed as <UNK> either way.\n",
    "            if word not in seen_words and word.strip():\n",
    "                seen_words.add(word)\n",
    "                vocab_words.append(word)\n",
    "        datas.append(content_words)\n",
    "        labels.append(category)\n",
    "    with open(save_vocab_path, \"w\", encoding=\"utf8\") as fw:\n",
    "        # Write from the list rather than a dict: plain dicts do not guarantee\n",
    "        # insertion order before Python 3.7, which would make the word order\n",
    "        # (and therefore every word index) vary from run to run.\n",
    "        for word in vocab_words:\n",
    "            fw.write(word + \"\\n\")\n",
    "    return datas, labels\n",
    "\n",
    "# Load the saved vocabulary and build both lookup tables; special_words\n",
    "# (e.g. ['<PAD>', '<UNK>']) come first so they receive the lowest indices.\n",
    "def read_word_vocabs(save_vocab_path, special_words):\n",
    "    \"\"\"Return ``(idx2vocab, vocab2idx)`` for the special words followed by the file's words.\n",
    "\n",
    "    Every line of ``save_vocab_path`` is stripped of surrounding whitespace and\n",
    "    becomes one entry; an entry's index is its position in\n",
    "    ``special_words + file entries``.\n",
    "    \"\"\"\n",
    "    file_words = []\n",
    "    with open(save_vocab_path, \"r\", encoding=\"utf8\") as fo:\n",
    "        for raw_line in fo:\n",
    "            file_words.append(raw_line.strip())\n",
    "    ordered_words = special_words + file_words\n",
    "    idx2vocab = dict(enumerate(ordered_words))\n",
    "    vocab2idx = {}\n",
    "    for position, entry in idx2vocab.items():\n",
    "        vocab2idx[entry] = position\n",
    "    return idx2vocab, vocab2idx\n",
    "\n",
    "# Convert the segmented samples into sequences of word ids and the labels into class ids.\n",
    "def process_dataset(datas, labels, category2idx, vocab2idx):\n",
    "    \"\"\"Index every word through ``vocab2idx`` and every label through ``category2idx``.\n",
    "\n",
    "    Words missing from ``vocab2idx`` are replaced by the id of ``'<UNK>'``.\n",
    "    Returns ``(new_datas, new_labels)`` as two parallel lists.\n",
    "    \"\"\"\n",
    "    def encode(words):\n",
    "        ids = []\n",
    "        for token in words:\n",
    "            if token in vocab2idx:\n",
    "                ids.append(vocab2idx[token])\n",
    "            else:\n",
    "                ids.append(vocab2idx['<UNK>'])\n",
    "        return ids\n",
    "\n",
    "    encoded_pairs = [(encode(words), category2idx[tag]) for words, tag in zip(datas, labels)]\n",
    "    new_datas = [ids for ids, _ in encoded_pairs]\n",
    "    new_labels = [class_id for _, class_id in encoded_pairs]\n",
    "    return new_datas, new_labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████| 382688/382688 [12:34<00:00, 507.49it/s]\n"
     ]
    }
   ],
   "source": [
    "# --- Paths, special tokens and label set ---\n",
    "data_path = \"./data/toutiao_news_dataset.txt\"\n",
    "save_vocab_path = \"./data/word_vocabs.txt\"\n",
    "# Placed in front of the vocabulary, so '<PAD>' gets index 0 and '<UNK>' index 1.\n",
    "special_words = ['<PAD>', '<UNK>']\n",
    "category_lists = [\n",
    "    \"民生故事\", \"文化\", \"娱乐\", \"体育\", \"财经\", \"房产\", \"汽车\", \"教育\", \"科技\", \"军事\",\n",
    "    \"旅游\", \"国际\", \"证券股票\", \"农业\", \"电竞游戏\",\n",
    "]\n",
    "\n",
    "# Category name <-> class id, both derived from the position in category_lists.\n",
    "idx2category = dict(enumerate(category_lists))\n",
    "category2idx = {cate: idx for idx, cate in idx2category.items()}\n",
    "\n",
    "# Segment the corpus (this also writes the vocabulary file), reload the vocabulary\n",
    "# with the special tokens in front, then map words and labels to integer ids.\n",
    "datas, labels = read_toutiao_dataset(data_path, save_vocab_path)\n",
    "idx2vocab, vocab2idx = read_word_vocabs(save_vocab_path, special_words)\n",
    "all_datas, all_labels = process_dataset(datas, labels, category2idx, vocab2idx)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['京城', '最', '值得', '你', '来', '场', '文化', '之旅', '的', '博物馆'] 文化\n",
      "<PAD> <UNK> 你\n",
      "民生故事 文化 娱乐\n"
     ]
    }
   ],
   "source": [
    "# Spot-check the first segmented sample and a few entries of both lookup tables.\n",
    "print(datas[0], labels[0])\n",
    "print(*[idx2vocab[i] for i in (0, 1, 5)])\n",
    "print(*[idx2category[i] for i in (0, 1, 2)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import numpy\n",
    "import keras\n",
    "from keras import backend as K\n",
    "from keras import activations\n",
    "from keras.engine.topology import Layer\n",
    "from keras.preprocessing import sequence\n",
    "from keras.models import Sequential\n",
    "from keras.models import Model\n",
    "from keras.layers import Input, Dense, Embedding, LSTM, Bidirectional\n",
    "K.clear_session()\n",
    "\n",
    "class AttentionLayer(Layer):\n",
    "    \"\"\"Attention pooling over the time axis of a 3-D sequence tensor.\n",
    "\n",
    "    For an input ``h`` of shape ``(batch, time_steps, hidden_size)`` the layer computes\n",
    "\n",
    "        u     = tanh(h . W + b)          # (batch, time_steps, attention_size)\n",
    "        alpha = softmax(u . V, axis=1)    # (batch, time_steps, 1)\n",
    "        out   = sum(alpha * h, axis=1)    # (batch, hidden_size)\n",
    "\n",
    "    Args:\n",
    "        attention_size: Width of the projection used for scoring. When ``None``\n",
    "            it is set to the input's last dimension in ``build``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, attention_size=None, **kwargs):\n",
    "        self.attention_size = attention_size\n",
    "        super(AttentionLayer, self).__init__(**kwargs)\n",
    "\n",
    "    def get_config(self):\n",
    "        \"\"\"Return the base layer config extended with ``attention_size``.\"\"\"\n",
    "        config = super().get_config()\n",
    "        config['attention_size'] = self.attention_size\n",
    "        return config\n",
    "\n",
    "    def build(self, input_shape):\n",
    "        \"\"\"Create the weights W, b and V for a ``(batch, time_steps, hidden_size)`` input.\"\"\"\n",
    "        assert len(input_shape) == 3\n",
    "\n",
    "        self.time_steps = input_shape[1]\n",
    "        hidden_size = input_shape[2]\n",
    "        if self.attention_size is None:\n",
    "            self.attention_size = hidden_size\n",
    "\n",
    "        self.W = self.add_weight(name='att_weight', shape=(hidden_size, self.attention_size),\n",
    "                                initializer='uniform', trainable=True)\n",
    "        self.b = self.add_weight(name='att_bias', shape=(self.attention_size,),\n",
    "                                initializer='uniform', trainable=True)\n",
    "        self.V = self.add_weight(name='att_var', shape=(self.attention_size,),\n",
    "                                initializer='uniform', trainable=True)\n",
    "        super(AttentionLayer, self).build(input_shape)\n",
    "\n",
    "    def call(self, inputs):\n",
    "        \"\"\"Return the attention-weighted sum of ``inputs`` over the time axis.\"\"\"\n",
    "        # Keep the reshaped column vector local: assigning it back to self.V would\n",
    "        # replace the weight attribute with a derived tensor on every call.\n",
    "        v_column = K.reshape(self.V, (-1, 1))\n",
    "        hidden = K.tanh(K.dot(inputs, self.W) + self.b)\n",
    "        # Softmax over axis 1 (time) so the weights of each sequence sum to 1.\n",
    "        score = K.softmax(K.dot(hidden, v_column), axis=1)\n",
    "        # score keeps a trailing singleton axis and broadcasts across hidden_size.\n",
    "        outputs = K.sum(score * inputs, axis=1)\n",
    "        return outputs\n",
    "\n",
    "    def compute_output_shape(self, input_shape):\n",
    "        \"\"\"The time axis is reduced away, leaving ``(batch, hidden_size)``.\"\"\"\n",
    "        return input_shape[0], input_shape[2]\n",
    "    \n",
    "    \n",
    "def create_classify_model(max_len, vocab_size, embedding_size, hidden_size, attention_size, class_nums):\n",
    "    \"\"\"Build the Embedding -> BiLSTM -> AttentionLayer -> softmax classifier.\n",
    "\n",
    "    The input is a batch of ``max_len`` int32 word ids; the output holds\n",
    "    ``class_nums`` softmax probabilities per sample. The model summary is\n",
    "    printed before the (uncompiled) model is returned.\n",
    "    \"\"\"\n",
    "    token_ids = Input(shape=(max_len,), dtype='int32')\n",
    "    embedded = Embedding(vocab_size, embedding_size)(token_ids)\n",
    "    encoder = Bidirectional(LSTM(hidden_size, dropout=0.2, return_sequences=True))\n",
    "    states = encoder(embedded)\n",
    "    pooled = AttentionLayer(attention_size=attention_size)(states)\n",
    "    probabilities = Dense(class_nums, activation='softmax')(pooled)\n",
    "    classifier = Model(inputs=token_ids, outputs=probabilities)\n",
    "    classifier.summary()  # print the layer structure and parameter counts\n",
    "    return classifier\n",
    "\n",
    "# --- Hyper-parameters ---\n",
    "MAX_LEN = 30  # every sample is padded / truncated to this many tokens\n",
    "VOCAB_SIZE = len(vocab2idx)\n",
    "EMBEDDING_SIZE = 100\n",
    "HIDDEN_SIZE = 64\n",
    "ATT_SIZE = 50\n",
    "CLASS_NUMS = len(category2idx)\n",
    "BATCH_SIZE = 64\n",
    "EPOCHS = 20\n",
    "\n",
    "# Only the first 10000 samples are used.\n",
    "all_datas = all_datas[:10000]\n",
    "all_labels = all_labels[:10000]\n",
    "\n",
    "count = len(all_labels)\n",
    "rate1, rate2 = 0.8, 0.9  # cumulative boundaries: 80% train, 10% test, 10% dev\n",
    "train_end = int(count * rate1)\n",
    "test_end = int(count * rate2)\n",
    "\n",
    "# Pad / truncate every sequence to MAX_LEN and one-hot encode the labels.\n",
    "new_datas = sequence.pad_sequences(all_datas, maxlen=MAX_LEN)\n",
    "new_labels = keras.utils.to_categorical(all_labels, CLASS_NUMS)\n",
    "\n",
    "# Contiguous, unshuffled slices: train first, then test, then dev.\n",
    "x_train, y_train = new_datas[:train_end], new_labels[:train_end]\n",
    "x_test, y_test = new_datas[train_end:test_end], new_labels[train_end:test_end]\n",
    "x_dev, y_dev = new_datas[test_end:], new_labels[test_end:]\n",
    "\n",
    "# Build, compile and train; the test slice serves as validation data during fitting.\n",
    "model = create_classify_model(MAX_LEN, VOCAB_SIZE, EMBEDDING_SIZE, HIDDEN_SIZE, ATT_SIZE, CLASS_NUMS)\n",
    "model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])\n",
    "model.fit(x_train, y_train, batch_size=BATCH_SIZE, epochs=EPOCHS, validation_data=(x_test, y_test))\n",
    "# x_dev / y_dev are not used here; they are scored in the following cell."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Score the dev slice, which was used neither for training nor for validation.\n",
    "score, acc = model.evaluate(x_dev, y_dev, batch_size=BATCH_SIZE)\n",
    "print('Test score:', score)\n",
    "print('Test accuracy:', acc)\n",
    "\n",
    "# Save the trained model. Create the target directory first: writing the .h5\n",
    "# file does not create missing parent directories, so on a fresh checkout the\n",
    "# save would otherwise fail after the whole training run.\n",
    "model_path = \"./model/news_classify_model.h5\"\n",
    "os.makedirs(os.path.dirname(model_path), exist_ok=True)\n",
    "model.save(model_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[0.00109552 0.00387199 0.21370623 0.00758541 0.5126246  0.00125148\n",
      "  0.00908471 0.01694475 0.01339941 0.00030947 0.00917827 0.00754309\n",
      "  0.08456795 0.1168716  0.00196562]]\n",
      "4 财经\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "from pyhanlp import HanLP\n",
    "from keras.preprocessing import sequence\n",
    "np.set_printoptions(suppress=True)\n",
    "\n",
    "maxlen = 30  # must equal the MAX_LEN the model was trained with\n",
    "content = \"网友数博会：大数据与实体经济融合发展，贵州焕发新动力\"\n",
    "# Segment the headline and map every word to its id (unknown words -> <UNK>).\n",
    "content_words = [term.word for term in HanLP.segment(content)]\n",
    "sent2id = [vocab2idx[word] if word in vocab2idx else vocab2idx['<UNK>'] for word in content_words]\n",
    "# Pad with the same routine, and therefore the same defaults (zeros in front,\n",
    "# truncation from the front), that prepared the training data. Padding at the\n",
    "# end instead would feed the model a layout it never saw during training.\n",
    "sent2id_new = sequence.pad_sequences([sent2id], maxlen=maxlen)\n",
    "\n",
    "y_pred = model.predict(sent2id_new)\n",
    "print(y_pred)\n",
    "# Report the most probable class id together with its category name.\n",
    "y_label = np.argmax(y_pred[0])\n",
    "print(y_label, idx2category[y_label])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[0.00109552 0.00387199 0.21370623 0.00758541 0.5126246  0.00125148\n",
      "  0.00908471 0.01694475 0.01339941 0.00030947 0.00917827 0.00754309\n",
      "  0.08456795 0.1168716  0.00196562]]\n",
      "[('财经', 0.5126246), ('娱乐', 0.21370623), ('农业', 0.1168716), ('证券股票', 0.08456795), ('教育', 0.016944751), ('科技', 0.01339941), ('旅游', 0.00917827), ('汽车', 0.009084707), ('体育', 0.0075854124), ('国际', 0.007543085), ('文化', 0.0038719943), ('电竞游戏', 0.0019656185), ('房产', 0.0012514826), ('民生故事', 0.0010955166), ('军事', 0.0003094678)]\n"
     ]
    }
   ],
   "source": [
    "from keras.models import load_model\n",
    "from keras.preprocessing import sequence\n",
    "import numpy as np\n",
    "from pyhanlp import HanLP\n",
    "np.set_printoptions(suppress=True)\n",
    "\n",
    "# --- Configuration (must match the values used for preprocessing and training) ---\n",
    "save_vocab_path = \"./data/word_vocabs.txt\"\n",
    "model_path = \"./model/news_classify_model.h5\"\n",
    "special_words = ['<PAD>', '<UNK>']\n",
    "category_lists = [\"民生故事\",\"文化\",\"娱乐\",\"体育\",\"财经\",\"房产\",\"汽车\",\"教育\",\"科技\",\"军事\",\n",
    "                \"旅游\",\"国际\",\"证券股票\",\"农业\",\"电竞游戏\"]\n",
    "maxlen = 30\n",
    "# Kept for reference only: the layer's attention size is restored from the\n",
    "# config stored in the saved model (AttentionLayer.get_config).\n",
    "ATT_SIZE = 50\n",
    "\n",
    "category2idx = {cate: idx for idx, cate in enumerate(category_lists)}\n",
    "idx2category = {idx: cate for idx, cate in enumerate(category_lists)}\n",
    "idx2vocab, vocab2idx = read_word_vocabs(save_vocab_path, special_words)\n",
    "\n",
    "content = \"网友数博会：大数据与实体经济融合发展，贵州焕发新动力\"\n",
    "# Segment the headline and map every word to its id (unknown words -> <UNK>).\n",
    "content_words = [term.word for term in HanLP.segment(content)]\n",
    "sent2id = [vocab2idx[word] if word in vocab2idx else vocab2idx['<UNK>'] for word in content_words]\n",
    "# Pad with the same routine, and therefore the same defaults (zeros in front,\n",
    "# truncation from the front), that prepared the training data. Padding at the\n",
    "# end instead would feed the model a layout it never saw during training.\n",
    "sent2id_new = sequence.pad_sequences([sent2id], maxlen=maxlen)\n",
    "\n",
    "# custom_objects maps the serialized class name to the class itself (not to an\n",
    "# instance); the loader instantiates it from the saved layer config.\n",
    "model = load_model(model_path, custom_objects={'AttentionLayer': AttentionLayer}, compile=False)\n",
    "y_pred = model.predict(sent2id_new)\n",
    "print(y_pred)\n",
    "# Pair every category with its probability and list them from most to least likely.\n",
    "result = {idx2category[idx]: pred for idx, pred in enumerate(y_pred[0])}\n",
    "result_sorted = sorted(result.items(), key=lambda item: item[1], reverse=True)\n",
    "print(result_sorted)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python [default]",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
